msg_tool\scripts\qlie\archive\pack/
encryption.rs

1use super::types::*;
2use crate::ext::io::*;
3use crate::scripts::base::*;
4use crate::types::*;
5use crate::utils::encoding::*;
6use crate::utils::mmx::*;
7use anyhow::Result;
8use std::io::{Read, Seek, SeekFrom, Write};
9
10pub trait Hasher {
11    fn update(&mut self, data: &[u8]) -> Result<()>;
12    fn finalize(&mut self) -> Result<u32>;
13}
14
15pub trait Encryption: std::fmt::Debug {
16    fn is_unicode(&self) -> bool {
17        false
18    }
19    fn compute_hash(&self, _data: &[u8]) -> Result<u32> {
20        Ok(0)
21    }
22    fn create_hash(&self) -> Result<Box<dyn Hasher>> {
23        Err(anyhow::anyhow!("Hasher not implemented"))
24    }
25    fn decrypt_name(&self, name: &mut [u8], hash: i32, encoding: Encoding) -> Result<String>;
26    fn decrypt_entry<'a>(
27        &self,
28        stream: Box<dyn ReadSeek + 'a>,
29        entry: &QlieEntry,
30    ) -> Result<Box<dyn ReadDebug + 'a>>;
31}
32
33pub fn create_encryption(major: u8, minor: u8) -> Result<Box<dyn Encryption>> {
34    match (major, minor) {
35        (3, 1) => Ok(Box::new(Encryption31::new())),
36        _ => Err(anyhow::anyhow!(
37            "Unsupported encryption version: {}.{}",
38            major,
39            minor
40        )),
41    }
42}
43
44pub fn decompress<'a>(data: Box<dyn ReadDebug + 'a>) -> Result<Box<dyn ReadDebug + 'a>> {
45    Ok(Box::new(Decompressor::new(data)?))
46}
47
48pub fn decrypt(data: &mut [u8], key: u32) -> Result<()> {
49    let length = data.len();
50    if length < 8 {
51        // Nothing to decrypt
52        return Ok(());
53    }
54    let mut data = MemWriterRef::new(data);
55    const C1: u64 = 0xA73C5F9D;
56    const C2: u64 = 0xCE24F523;
57    const C3: u64 = 0xFEC9753E;
58    let mut v5 = mmx_punpckldq2(C1);
59    const V7: u64 = mmx_punpckldq2(C2);
60    let mut v9 = mmx_punpckldq2(((length as u32).wrapping_add(key) as u64) ^ C3);
61    for _ in 0..length / 8 {
62        let d = data.peek_u64()?;
63        v5 = mmx_p_add_d(v5, V7) ^ v9;
64        v9 = d ^ v5;
65        data.write_u64(v9)?;
66    }
67    Ok(())
68}
69
70pub fn encrypt(data: &mut [u8], key: u32) -> Result<()> {
71    let length = data.len();
72    if length < 8 {
73        // Nothing to encrypt
74        return Ok(());
75    }
76    let mut data = MemWriterRef::new(data);
77    const C1: u64 = 0xA73C5F9D;
78    const C2: u64 = 0xCE24F523;
79    const C3: u64 = 0xFEC9753E;
80    let mut v5 = mmx_punpckldq2(C1);
81    const V7: u64 = mmx_punpckldq2(C2);
82    let mut v9 = mmx_punpckldq2(((length as u32).wrapping_add(key) as u64) ^ C3);
83    for _ in 0..length / 8 {
84        let mut d = data.peek_u64()?;
85        v5 = mmx_p_add_d(v5, V7) ^ v9;
86        v9 = d;
87        d ^= v5;
88        data.write_u64(d)?;
89    }
90    Ok(())
91}
92
93pub fn get_common_key(data: &[u8]) -> Result<Vec<u8>> {
94    let mut reader = MemReaderRef::new(data);
95    let mut key = vec![0u8; 0x400];
96    let mut writer = MemWriterRef::new(&mut key);
97    for i in 0..0x100i32 {
98        let temp = if (i % 3) != 0 {
99            (i + 7) * -(i + 3)
100        } else {
101            (i + 7) * (i + 3)
102        };
103        writer.write_u32_at(i as u64 * 4, temp as u32)?;
104    }
105    let mut v1 = (reader.peek_u8_at(49)? % 0x49) as i32 + 0x80;
106    let v2 = (reader.peek_u8_at(79)? % 7) as i32 + 7;
107    let data_len = data.len() as i32;
108    for i in 0..0x400 {
109        v1 = (v1.wrapping_add(v2)) % data_len;
110        key[i] ^= reader.peek_u8_at(v1 as u64)?;
111    }
112    // crate::utils::files::write_file("./testscripts/test.bin")?.write_all(&key)?;
113    Ok(key)
114}
115
116#[derive(Debug)]
117pub struct Encryption31 {}
118
119impl Encryption31 {
120    pub fn new() -> Self {
121        Self {}
122    }
123
124    fn create_table(length: usize, mut value: u32, is_v1: bool) -> Result<Vec<u8>> {
125        let mut table = Vec::with_capacity(length);
126        let key: u32 = if is_v1 { 0x8DF21431 } else { 0x8A77F473 };
127        for _ in 0..length {
128            let t = (key as u64).wrapping_mul((value as u64) ^ (key as u64));
129            value = ((t >> 32) + t) as u32;
130            table.push(value);
131        }
132        let mut mem = MemWriter::with_capacity(length * 4);
133        for i in table {
134            mem.write_u32(i)?;
135        }
136        Ok(mem.into_inner())
137    }
138
139    pub fn compute_name_hash(&self, name: &[u16]) -> Result<u32> {
140        let mut v2 = 0u32;
141        let mut v3 = name.len() as u32;
142        let mut v4 = 1u32;
143        if v3 > 0 {
144            loop {
145                let n = (name[(v4 - 1) as usize] as u32) << (v4 & 7);
146                v2 = v2.wrapping_add(n) & 0x3FFFFFFF;
147                v4 += 1;
148                v3 -= 1;
149                if v3 == 0 {
150                    break;
151                }
152            }
153        }
154        Ok(v2)
155    }
156
157    pub fn encrypt_name(&self, name: &mut [u8], hash: i32) -> Result<()> {
158        if name.len() % 2 != 0 {
159            return Err(anyhow::anyhow!(
160                "Invalid name length for Unicode encryption"
161            ));
162        }
163        let char_len = name.len() / 2;
164        let cl = char_len as i32;
165        let temp = (cl.wrapping_mul(cl) ^ cl ^ 0x3e13 ^ (hash >> 16) ^ hash) & 0xFFFF;
166        let mut key = temp;
167        for i in 0..char_len {
168            key = temp
169                .wrapping_add(i as i32)
170                .wrapping_add(key.wrapping_mul(8));
171            name[i * 2] ^= key as u8;
172            name[i * 2 + 1] ^= (key >> 8) as u8;
173        }
174        Ok(())
175    }
176}
177
178impl Encryption for Encryption31 {
179    fn is_unicode(&self) -> bool {
180        true
181    }
182    fn compute_hash(&self, data: &[u8]) -> Result<u32> {
183        let mut hasher = Encryption31Hasher::new();
184        hasher.update(data)?;
185        Ok(hasher.finalize()?)
186    }
187    fn create_hash(&self) -> Result<Box<dyn Hasher>> {
188        Ok(Box::new(Encryption31Hasher::new()))
189    }
190    fn decrypt_name(&self, name: &mut [u8], hash: i32, _encoding: Encoding) -> Result<String> {
191        if name.len() % 2 != 0 {
192            return Err(anyhow::anyhow!(
193                "Invalid name length for Unicode decryption"
194            ));
195        }
196        let char_len = name.len() / 2;
197        let cl = char_len as i32;
198        let temp = (cl.wrapping_mul(cl) ^ cl ^ 0x3e13 ^ (hash >> 16) ^ hash) & 0xFFFF;
199        let mut key = temp;
200        for i in 0..char_len {
201            key = temp
202                .wrapping_add(i as i32)
203                .wrapping_add(key.wrapping_mul(8));
204            name[i * 2] ^= key as u8;
205            name[i * 2 + 1] ^= (key >> 8) as u8;
206        }
207        Ok(decode_to_string(Encoding::Utf16LE, &name, true)?)
208    }
209    fn decrypt_entry<'a>(
210        &self,
211        stream: Box<dyn ReadSeek + 'a>,
212        entry: &QlieEntry,
213    ) -> Result<Box<dyn ReadDebug + 'a>> {
214        match entry.is_encrypted {
215            // No encryption
216            0 => Ok(Box::new(stream)),
217            1 => Ok(Box::new(Encryption31DecryptV1::new(
218                stream,
219                entry.size,
220                entry.name.clone(),
221                entry.key,
222            )?)),
223            2 => Ok(Box::new(Encryption31DecryptV2::new(
224                stream,
225                entry.size,
226                entry.name.clone(),
227                entry.key,
228                entry
229                    .common_key
230                    .clone()
231                    .ok_or(anyhow::anyhow!("Missing common key"))?,
232            )?)),
233            _ => Err(anyhow::anyhow!(
234                "Unsupported encryption flag: {}",
235                entry.is_encrypted
236            )),
237        }
238    }
239}
240
241pub struct Encryption31Hasher {
242    hash: u64,
243    key: u64,
244    buffer: [u8; 8],
245    buffer_len: usize,
246}
247
248impl Encryption31Hasher {
249    pub fn new() -> Self {
250        Self {
251            hash: 0,
252            key: 0,
253            buffer: [0; 8],
254            buffer_len: 0,
255        }
256    }
257
258    fn update_internal(&mut self, data: u64) {
259        const C: u64 = mmx_punpckldq2(0xA35793A7);
260        self.hash = mmx_p_add_w(self.hash, C);
261        let temp = mmx_p_add_w(self.key, self.hash ^ data);
262        self.key = mmx_p_sll_d(temp, 3) | mmx_p_srl_d(temp, 0x1d);
263    }
264}
265
266impl Hasher for Encryption31Hasher {
267    fn update(&mut self, data: &[u8]) -> Result<()> {
268        let mut used = 0;
269        if self.buffer_len > 0 {
270            let to_copy = (8 - self.buffer_len).min(data.len());
271            self.buffer[self.buffer_len..self.buffer_len + to_copy]
272                .copy_from_slice(&data[..to_copy]);
273            self.buffer_len += to_copy;
274            used += to_copy;
275        }
276        if self.buffer_len == 8 {
277            let v = u64::from_le_bytes(self.buffer);
278            self.update_internal(v);
279            self.buffer_len = 0;
280        }
281        let round = (data.len() - used) / 8;
282        let mut reader = MemReaderRef::new(&data[used..]);
283        for _ in 0..round {
284            let v = reader.read_u64()?;
285            self.update_internal(v);
286            used += 8;
287        }
288        let remaining = data.len() - used;
289        if remaining > 0 {
290            self.buffer[..remaining].copy_from_slice(&data[used..]);
291            self.buffer_len = remaining;
292        }
293        Ok(())
294    }
295
296    fn finalize(&mut self) -> Result<u32> {
297        let p1 = ((self.key as i16) as i32).wrapping_mul(((self.key >> 32) as i16) as i32);
298        let p2 = (((self.key >> 16) as i16) as i32).wrapping_mul(((self.key >> 48) as i16) as i32);
299        Ok((p1.wrapping_add(p2)) as u32)
300    }
301}
302
303#[derive(Debug)]
304struct Encryption31DecryptV1<'a> {
305    stream: Box<dyn ReadSeek + 'a>,
306    table: MemReader,
307    v4: u32,
308    v6: u64,
309}
310
311impl<'a> Encryption31DecryptV1<'a> {
312    pub fn new(
313        stream: Box<dyn ReadSeek + 'a>,
314        size: u32,
315        name: String,
316        key: u32,
317    ) -> Result<AlignedReader<8, Self>> {
318        let mut v1 = 0x85F532u32;
319        let mut v2 = 0x33F641u32;
320        for (i, n) in name.encode_utf16().enumerate() {
321            v1 = v1.wrapping_add((n as u32) << (i & 7));
322            v2 ^= v1;
323        }
324        v2 = v2.wrapping_add(
325            key ^ ((7 * (size & 0xFFFFFF))
326                .wrapping_add(size)
327                .wrapping_add(v1)
328                .wrapping_add(v1 ^ size ^ 0x8F32DC)),
329        );
330        v2 = 9 * (v2 & 0xFFFFFF);
331        let table = MemReader::new(Encryption31::create_table(0x40, v2, true)?);
332        let v4 = 8 * (table.cpeek_u32_at(52)? & 0xF);
333        let v6 = table.cpeek_u64_at(24)?;
334        let inner = Self {
335            stream,
336            table,
337            v4,
338            v6,
339        };
340        Ok(AlignedReader::new(inner))
341    }
342}
343
344impl<'a> Read for Encryption31DecryptV1<'a> {
345    fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
346        let readed = self.stream.read_most(buf)?;
347        let round = readed / 8;
348        let mut writer = MemWriterRef::new(buf);
349        for _ in 0..round {
350            let d = writer.peek_u64()?;
351            let temp = self.table.cpeek_u64_at(self.v4 as u64)?;
352            let v7 = mmx_p_add_d(self.v6 ^ temp, temp);
353            let v8 = d ^ v7;
354            writer.write_u64(v8)?;
355            self.v6 = mmx_p_add_w(mmx_p_sll_d(mmx_p_add_b(v7, v8) ^ v8, 1), v8);
356            self.v4 = (self.v4 + 8) & 0x7F;
357        }
358        Ok(readed)
359    }
360}
361
362#[derive(Debug)]
363pub struct Encryption31EncryptV1<T: Write> {
364    stream: T,
365    table: MemReader,
366    v4: u32,
367    v6: u64,
368}
369
370impl<T: Write> Encryption31EncryptV1<T> {
371    pub fn new(stream: T, size: u32, name: String, key: u32) -> Result<AlignedWriter<8, Self>> {
372        let mut v1 = 0x85F532u32;
373        let mut v2 = 0x33F641u32;
374        for (i, n) in name.encode_utf16().enumerate() {
375            v1 = v1.wrapping_add((n as u32) << (i & 7));
376            v2 ^= v1;
377        }
378        v2 = v2.wrapping_add(
379            key ^ ((7 * (size & 0xFFFFFF))
380                .wrapping_add(size)
381                .wrapping_add(v1)
382                .wrapping_add(v1 ^ size ^ 0x8F32DC)),
383        );
384        v2 = 9 * (v2 & 0xFFFFFF);
385        let table = MemReader::new(Encryption31::create_table(0x40, v2, true)?);
386        let v4 = 8 * (table.cpeek_u32_at(52)? & 0xF);
387        let v6 = table.cpeek_u64_at(24)?;
388        let inner = Self {
389            stream,
390            table,
391            v4,
392            v6,
393        };
394        Ok(AlignedWriter::new(inner))
395    }
396}
397
398impl<T: Write> Write for Encryption31EncryptV1<T> {
399    fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
400        let round = buf.len() / 8;
401        let mut reader = MemReaderRef::new(buf);
402        for _ in 0..round {
403            let d = reader.read_u64()?;
404            let temp = self.table.cpeek_u64_at(self.v4 as u64)?;
405            let v7 = mmx_p_add_d(self.v6 ^ temp, temp);
406            let v8 = d ^ v7;
407            self.stream.write_u64(v8)?;
408            self.v6 = mmx_p_add_w(mmx_p_sll_d(mmx_p_add_b(v7, d) ^ d, 1), d);
409            self.v4 = (self.v4 + 8) & 0x7F;
410        }
411        let remain = buf.len() % 8;
412        if remain > 0 {
413            self.stream.write_all(&buf[buf.len() - remain..])?;
414        }
415        Ok(buf.len())
416    }
417
418    fn flush(&mut self) -> std::io::Result<()> {
419        self.stream.flush()
420    }
421}
422
423#[derive(Debug)]
424struct Encryption31DecryptV2<'a> {
425    stream: Box<dyn ReadSeek + 'a>,
426    table: MemReader,
427    v4: u32,
428    v6: u64,
429    common_key: MemReader,
430}
431
432impl<'a> Encryption31DecryptV2<'a> {
433    pub fn new(
434        stream: Box<dyn ReadSeek + 'a>,
435        size: u32,
436        name: String,
437        key: u32,
438        common_key: Vec<u8>,
439    ) -> Result<AlignedReader<8, Self>> {
440        let mut v1 = 0x86F7E2u32;
441        let mut v2 = 0x4437F1u32;
442        for (i, n) in name.encode_utf16().enumerate() {
443            v1 = v1.wrapping_add((n as u32) << (i & 7));
444            v2 ^= v1;
445        }
446        v2 = v2.wrapping_add(
447            key ^ ((13 * (size & 0xFFFFFF))
448                .wrapping_add(size)
449                .wrapping_add(v1)
450                .wrapping_add(v1 ^ size ^ 0x56E213)),
451        );
452        v2 = 13 * (v2 & 0xFFFFFF);
453        let table = MemReader::new(Encryption31::create_table(0x40, v2, false)?);
454        let v4 = 8 * (table.cpeek_u32_at(32)? & 0xD);
455        let v6 = table.cpeek_u64_at(24)?;
456        let inner = Self {
457            stream,
458            table,
459            v4,
460            v6,
461            common_key: MemReader::new(common_key),
462        };
463        Ok(AlignedReader::new(inner))
464    }
465}
466
467impl<'a> Read for Encryption31DecryptV2<'a> {
468    fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
469        let readed = self.stream.read_most(buf)?;
470        let round = readed / 8;
471        let mut writer = MemWriterRef::new(buf);
472        for _ in 0..round {
473            let d = writer.peek_u64()?;
474            let temp_index1 = ((self.v4 & 0xF) * 8) as u64;
475            let temp_index2 = ((self.v4 & 0x7F) * 8) as u64;
476            let temp = self.table.cpeek_u64_at(temp_index1)?
477                ^ self.common_key.cpeek_u64_at(temp_index2)?;
478            let v7 = mmx_p_add_d(self.v6 ^ temp, temp);
479            let v8 = d ^ v7;
480            writer.write_u64(v8)?;
481            self.v6 = mmx_p_add_w(mmx_p_sll_d(mmx_p_add_b(v7, v8) ^ v8, 1), v8);
482            self.v4 = (self.v4 + 1) & 0x7F;
483        }
484        Ok(readed)
485    }
486}
487
488#[derive(Debug)]
489pub struct Encryption31EncryptV2<T: Write> {
490    stream: T,
491    table: MemReader,
492    v4: u32,
493    v6: u64,
494    common_key: MemReader,
495}
496
497impl<T: Write> Encryption31EncryptV2<T> {
498    pub fn new(
499        stream: T,
500        size: u32,
501        name: String,
502        key: u32,
503        common_key: Vec<u8>,
504    ) -> Result<AlignedWriter<8, Self>> {
505        let mut v1 = 0x86F7E2u32;
506        let mut v2 = 0x4437F1u32;
507        for (i, n) in name.encode_utf16().enumerate() {
508            v1 = v1.wrapping_add((n as u32) << (i & 7));
509            v2 ^= v1;
510        }
511        v2 = v2.wrapping_add(
512            key ^ ((13 * (size & 0xFFFFFF))
513                .wrapping_add(size)
514                .wrapping_add(v1)
515                .wrapping_add(v1 ^ size ^ 0x56E213)),
516        );
517        v2 = 13 * (v2 & 0xFFFFFF);
518        let table = MemReader::new(Encryption31::create_table(0x40, v2, false)?);
519        let v4 = 8 * (table.cpeek_u32_at(32)? & 0xD);
520        let v6 = table.cpeek_u64_at(24)?;
521        let inner = Self {
522            stream,
523            table,
524            v4,
525            v6,
526            common_key: MemReader::new(common_key),
527        };
528        Ok(AlignedWriter::new(inner))
529    }
530}
531
532impl<T: Write> Write for Encryption31EncryptV2<T> {
533    fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
534        let round = buf.len() / 8;
535        let mut reader = MemReaderRef::new(buf);
536        for _ in 0..round {
537            let d = reader.read_u64()?;
538            let temp_index1 = ((self.v4 & 0xF) * 8) as u64;
539            let temp_index2 = ((self.v4 & 0x7F) * 8) as u64;
540            let temp = self.table.cpeek_u64_at(temp_index1)?
541                ^ self.common_key.cpeek_u64_at(temp_index2)?;
542            let v7 = mmx_p_add_d(self.v6 ^ temp, temp);
543            let v8 = d ^ v7;
544            self.stream.write_u64(v8)?;
545            self.v6 = mmx_p_add_w(mmx_p_sll_d(mmx_p_add_b(v7, d) ^ d, 1), d);
546            self.v4 = (self.v4 + 1) & 0x7F;
547        }
548        let remain = buf.len() % 8;
549        if remain > 0 {
550            self.stream.write_all(&buf[buf.len() - remain..])?;
551        }
552        Ok(buf.len())
553    }
554
555    fn flush(&mut self) -> std::io::Result<()> {
556        self.stream.flush()
557    }
558}
559
560#[derive(Debug)]
561pub struct Decompressor<'a> {
562    stream: Box<dyn ReadDebug + 'a>,
563    is_16bit: bool,
564    temp: Vec<u8>,
565    buf: Vec<u8>,
566    buf_pos: usize,
567}
568
569impl<'a> Decompressor<'a> {
570    pub fn new(mut stream: Box<dyn ReadDebug + 'a>) -> Result<Self> {
571        let sign = stream.read_u32()?;
572        if sign != 0xFF435031 {
573            return Err(anyhow::anyhow!("Invalid compression signature"));
574        }
575        let is_16bit = stream.read_u32()? & 1 != 0;
576        let _unpacked_size = stream.read_u32()?;
577        let temp = vec![0u8; 0x1000];
578        Ok(Self {
579            stream,
580            is_16bit,
581            temp,
582            buf: Vec::new(),
583            buf_pos: 0,
584        })
585    }
586
587    fn next_block(&mut self) -> Result<()> {
588        self.buf.clear();
589        self.buf_pos = 0;
590        let mut buf = [0u8; 1];
591        let readed = self.stream.read(&mut buf)?;
592        if readed == 0 {
593            return Ok(());
594        }
595        let mut buf_used = false;
596        let mut table = [[0u8; 2]; 0x100];
597        let mut i = 0u32;
598        while i < 0x100 {
599            let mut c = if !buf_used {
600                buf_used = true;
601                buf[0] as u32
602            } else {
603                self.stream.read_u8()? as u32
604            };
605            if c > 127 {
606                c -= 127;
607                while c > 0 {
608                    table[i as usize][0] = i as u8;
609                    c -= 1;
610                    i += 1;
611                }
612            }
613            c += 1;
614            while c > 0 && i < 0x100 {
615                table[i as usize][0] = self.stream.read_u8()?;
616                if i as u8 != table[i as usize][0] {
617                    table[i as usize][1] = self.stream.read_u8()?;
618                }
619                c -= 1;
620                i += 1;
621            }
622        }
623        let mut block_size = if self.is_16bit {
624            self.stream.read_u16()? as usize
625        } else {
626            self.stream.read_u32()? as usize
627        };
628        let mut temp_length = 0usize;
629        while block_size > 0 || temp_length > 0 {
630            let c = if temp_length > 0 {
631                temp_length -= 1;
632                self.temp[temp_length]
633            } else {
634                block_size -= 1;
635                self.stream.read_u8()?
636            };
637            if c == table[c as usize][0] {
638                self.buf.push(c);
639            } else {
640                self.temp[temp_length] = table[c as usize][1];
641                temp_length += 1;
642                self.temp[temp_length] = table[c as usize][0];
643                temp_length += 1;
644            }
645        }
646        Ok(())
647    }
648}
649
650impl<'a> Read for Decompressor<'a> {
651    fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
652        let mut used = 0;
653        while used < buf.len() {
654            if self.buf_pos >= self.buf.len() {
655                self.next_block()
656                    .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))?;
657                if self.buf.is_empty() {
658                    break;
659                }
660            }
661            let to_copy = (self.buf.len() - self.buf_pos).min(buf.len() - used);
662            buf[used..used + to_copy]
663                .copy_from_slice(&self.buf[self.buf_pos..self.buf_pos + to_copy]);
664            self.buf_pos += to_copy;
665            used += to_copy;
666        }
667        Ok(used)
668    }
669}
670
671pub struct Compressor<W: Write + Seek> {
672    stream: W,
673    buffer: Vec<u8>,
674    total_unpacked_size: u32,
675    is_finished: bool,
676}
677
678impl<W: Write + Seek> Compressor<W> {
679    pub fn new(mut stream: W) -> Result<Self> {
680        stream.write_u32(0xFF435031)?;
681        stream.write_u32(0)?;
682        stream.write_u32(0)?;
683        Ok(Self {
684            stream,
685            buffer: Vec::new(),
686            total_unpacked_size: 0,
687            is_finished: false,
688        })
689    }
690
691    pub fn finish(&mut self) -> Result<()> {
692        if self.is_finished {
693            return Ok(());
694        }
695        if !self.buffer.is_empty() {
696            self.flush_block()?;
697        }
698        let pos = self.stream.stream_position()?;
699        self.stream.seek(SeekFrom::Start(8))?;
700        self.stream.write_u32(self.total_unpacked_size)?;
701        self.stream.seek(SeekFrom::Start(pos))?;
702        self.is_finished = true;
703        Ok(())
704    }
705
706    fn flush_block(&mut self) -> Result<()> {
707        if self.buffer.is_empty() {
708            return Ok(());
709        }
710        let (table, data) = compress_algo(&self.buffer);
711
712        // Write table
713        write_table(&mut self.stream, &table)?;
714
715        // Write block size
716        self.stream.write_u32(data.len() as u32)?;
717        self.stream.write_all(&data)?;
718
719        self.total_unpacked_size += self.buffer.len() as u32;
720        self.buffer.clear();
721        Ok(())
722    }
723}
724
725impl<W: Write + Seek> Write for Compressor<W> {
726    fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
727        let mut pos = 0;
728        while pos < buf.len() {
729            let space = 0x10000 - self.buffer.len();
730            let copy = space.min(buf.len() - pos);
731            self.buffer.extend_from_slice(&buf[pos..pos + copy]);
732            pos += copy;
733            if self.buffer.len() >= 0x10000 {
734                self.flush_block()
735                    .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))?;
736            }
737        }
738        Ok(buf.len())
739    }
740
741    fn flush(&mut self) -> std::io::Result<()> {
742        self.stream.flush()
743    }
744}
745
746impl<W: Write + Seek> Drop for Compressor<W> {
747    fn drop(&mut self) {
748        let _ = self.finish();
749    }
750}
751
752fn write_table<W: Write>(writer: &mut W, table: &[[u8; 2]; 256]) -> Result<()> {
753    let mut i = 0;
754    while i < 256 {
755        // Count consecutive identities
756        let mut n_identities = 0;
757        let mut j = i;
758        while j < 256 && table[j][0] == j as u8 {
759            n_identities += 1;
760            j += 1;
761        }
762
763        if n_identities > 0 {
764            let k = n_identities.min(128);
765            if i + k == 256 {
766                writer.write_u8(127 + k as u8)?;
767                i += k;
768            } else {
769                writer.write_u8(127 + k as u8)?;
770                i += k;
771                // Write explicit
772                writer.write_u8(table[i][0])?;
773                if table[i][0] != i as u8 {
774                    writer.write_u8(table[i][1])?;
775                }
776                i += 1;
777            }
778        } else {
779            let mut count = 0;
780            let mut j = i;
781            while j < 256 && count < 128 {
782                if j + 1 < 256 && table[j][0] == j as u8 && table[j + 1][0] == (j + 1) as u8 {
783                    break;
784                }
785                count += 1;
786                j += 1;
787            }
788
789            writer.write_u8((count - 1) as u8)?;
790            for k in 0..count {
791                let curr = i + k;
792                writer.write_u8(table[curr][0])?;
793                if table[curr][0] != curr as u8 {
794                    writer.write_u8(table[curr][1])?;
795                }
796            }
797            i += count;
798        }
799    }
800    Ok(())
801}
802
803fn compress_algo(input: &[u8]) -> ([[u8; 2]; 256], Vec<u8>) {
804    let mut tokens = input.to_vec();
805    let mut table = [[0u8; 2]; 256];
806    for i in 0..256 {
807        table[i][0] = i as u8;
808    }
809
810    let max_iterations = 256;
811    for _ in 0..max_iterations {
812        let mut pair_counts = vec![0u32; 65536];
813        let mut max_pair_idx = 0;
814        let mut max_pair_count = 0;
815
816        if tokens.len() < 2 {
817            break;
818        }
819
820        for i in 0..tokens.len() - 1 {
821            let pair = ((tokens[i] as usize) << 8) | (tokens[i + 1] as usize);
822            pair_counts[pair] += 1;
823            if pair_counts[pair] > max_pair_count {
824                max_pair_count = pair_counts[pair];
825                max_pair_idx = pair;
826            }
827        }
828
829        // Must appear at least twice to save space (2 bytes * 2 -> 1 byte * 2 + overhead)
830        if max_pair_count < 2 {
831            break;
832        }
833
834        let is_used = get_used_tokens(&tokens, &table);
835        let mut unused = None;
836        for i in 0..256 {
837            if !is_used[i] {
838                unused = Some(i as u8);
839                break;
840            }
841        }
842
843        if let Some(token) = unused {
844            let left = (max_pair_idx >> 8) as u8;
845            let right = (max_pair_idx & 0xFF) as u8;
846
847            table[token as usize] = [left, right];
848
849            let mut new_tokens = Vec::with_capacity(tokens.len());
850            let mut i = 0;
851            while i < tokens.len() {
852                if i + 1 < tokens.len() && tokens[i] == left && tokens[i + 1] == right {
853                    new_tokens.push(token);
854                    i += 2;
855                } else {
856                    new_tokens.push(tokens[i]);
857                    i += 1;
858                }
859            }
860            tokens = new_tokens;
861        } else {
862            break;
863        }
864    }
865    (table, tokens)
866}
867
/// Marks every token value that the current stream depends on.
///
/// A token is in use when it appears directly in `tokens`, or when it is a
/// child of an in-use composite entry (one whose first byte differs from its
/// own index). Such tokens must not be reassigned.
fn get_used_tokens(tokens: &[u8], table: &[[u8; 2]; 256]) -> [bool; 256] {
    let mut seen = [false; 256];
    let mut worklist: Vec<u8> = Vec::with_capacity(256);

    for &token in tokens {
        let slot = &mut seen[usize::from(token)];
        if !*slot {
            *slot = true;
            worklist.push(token);
        }
    }

    while let Some(token) = worklist.pop() {
        let [left, right] = table[usize::from(token)];
        if left == token {
            // Literal: nothing further is reachable through it.
            continue;
        }
        for child in [left, right].iter().copied() {
            if !seen[usize::from(child)] {
                seen[usize::from(child)] = true;
                worklist.push(child);
            }
        }
    }
    seen
}
901
902pub fn compress(data: &[u8]) -> Result<Vec<u8>> {
903    let mut cursor = std::io::Cursor::new(Vec::new());
904    {
905        let mut compressor = Compressor::new(&mut cursor)?;
906        compressor.write_all(data)?;
907        compressor.finish()?;
908    }
909    Ok(cursor.into_inner())
910}
911
/// Round-trips representative inputs through [`compress`] and [`decompress`]:
/// repetitive text, empty input (header only, no blocks), and a block that
/// contains every byte value (so no free token exists and data is stored
/// verbatim).
#[test]
fn test_compress_decompress() -> Result<()> {
    fn round_trip(data: &[u8]) -> Result<Vec<u8>> {
        let compressed = compress(data)?;
        let mut reader = decompress(Box::new(MemReaderRef::new(&compressed)))?;
        let mut output = Vec::new();
        reader.read_to_end(&mut output)?;
        Ok(output)
    }

    let data = b"The quick brown fox jumps over the lazy dog.".repeat(100);
    println!("Original size: {}", data.len());
    let compressed = compress(&data)?;
    println!("Compressed size: {}", compressed.len());
    assert_eq!(round_trip(&data)?.as_slice(), data.as_slice());

    assert!(round_trip(&[])?.is_empty());

    let all_bytes: Vec<u8> = (0..=255u8).collect();
    assert_eq!(round_trip(&all_bytes)?, all_bytes);
    Ok(())
}